classdef accelerationCurve
    % accelerationCurve  Piecewise-linear acceleration as a function of speed.
    %
    %   Between consecutive breakpoints spdBrk(k) -> accVal(k) the acceleration
    %   is linear, a(v) = b*v + c. The distance needed to go from speed v1 to
    %   v2 on such a segment is the integral of v / a(v) dv.
    %
    %   ac = accelerationCurve()                 default breakpoints/values
    %   ac = accelerationCurve(spdBrk, accVal)   custom breakpoints/values
    %
    %   NOTE: b, c and accDistSeg are computed once in the constructor.
    %   Changing spdBrk or accVal on an existing object does NOT update them;
    %   construct a new object instead.

    properties
        spdBrk (1,:) = [0 25 60 120]/3.6; % m/s
        accVal (1,:) = [1, 1, 0.5, 0.5]; % m/s^2
    end
    properties (SetAccess=immutable, Hidden=true)
        b           % slope of a(v) on each segment
        c           % intercept of a(v) on each segment
        accDistSeg  % distance to traverse each whole segment
    end
    properties (SetAccess=protected, Hidden=true)
        % Distance from v1 to v2 on a segment with a(v) = b*v + c, i.e. the
        % integral of v/(b*v + c) dv. Only valid for b ~= 0 (0/0 otherwise).
        sfun = @(b, c, v1, v2) -1 ./ b.^2 .* ( b .* v1 - b.*v2 + c.*log( ( c + b.*v2 ) ./ ( c + b.*v1 ) ) );
    end

    methods
        function obj = accelerationCurve(spdBrk, accVal)
            % Build the curve from speed breakpoints and acceleration values.
            %   spdBrk - strictly increasing speed breakpoints (at least two)
            %   accVal - acceleration at each breakpoint (same number of elements)
            % Omitted arguments keep the property defaults.
            if nargin >= 1
                obj.spdBrk = spdBrk;
            end
            if nargin >= 2
                obj.accVal = accVal;
            end

            if numel(obj.spdBrk) ~= numel(obj.accVal)
                error("accelerationCurve:sizeMismatch", ...
                    "spdBrk and accVal must have the same number of elements.")
            end
            if numel(obj.spdBrk) < 2 || ~all(diff(obj.spdBrk) > 0)
                error("accelerationCurve:badBreakpoints", ...
                    "spdBrk must contain at least two strictly increasing values.")
            end

            % Slope and intercept per segment, computed from the validated
            % row-vector properties (not the raw arguments, which may be columns)
            obj.b = diff(obj.accVal) ./ diff(obj.spdBrk);
            obj.c = obj.accVal(1:end-1) - obj.spdBrk(1:end-1) .* obj.b;

            % Distance to traverse each whole segment
            nSeg = numel(obj.spdBrk) - 1;
            segLen = zeros(1, nSeg);
            for n = 1:nSeg
                segLen(n) = obj.segDist(n, obj.spdBrk(n), obj.spdBrk(n+1));
            end
            obj.accDistSeg = segLen;
        end

        function val = eval(obj, spd)
            % Acceleration at the given speed(s).
            %   spd - row vector of speeds
            %   val - row vector, same size as spd. Linear interpolation between
            %         breakpoints; outside [spdBrk(1), spdBrk(end)] the first or
            %         last segment is extended linearly. NaN speeds give NaN.
            arguments
                obj
                spd (1,:)
            end

            nSeg = numel(obj.spdBrk) - 1;

            % Segment index = number of breakpoints <= spd, clamped to
            % [1, nSeg]. The clamp handles left/right extrapolation and NaN.
            % idx is kept as a row so that b(idx) is a row even when b is a
            % scalar (curve with only two breakpoints).
            idx = sum(obj.spdBrk(:) <= spd, 1);
            idx = min(max(idx, 1), nSeg);

            val = obj.b(idx) .* spd + obj.c(idx);

            % Return the stored value exactly when spd hits a breakpoint
            [onBrk, loc] = ismember(spd, obj.spdBrk);
            val(onBrk) = obj.accVal(loc(onBrk));
        end

        function accDist = accelDist(obj, spd1, spd2)
            % Distance needed to change speed between spd1 and spd2.
            %   spd1, spd2 - speeds within [spdBrk(1), spdBrk(end)]. They are
            %                combined with implicit expansion (scalar + array,
            %                row + column -> grid).
            %   accDist    - non-negative distance; the order of spd1 and spd2
            %                does not matter. NaN where either speed is NaN.
            % Errors if any speed lies outside the breakpoint range.

            % Broadcast both inputs to a common size (no-op if already equal)
            spd1 = spd1 + zeros(size(spd2));
            spd2 = spd2 + zeros(size(spd1));
            if any([spd1(:); spd2(:)] < obj.spdBrk(1)) || any([spd1(:); spd2(:)] > obj.spdBrk(end))
                error("Extrapolation not supported.")
            end

            % Order each pair so that lo <= hi
            lo = min(spd1, spd2);
            hi = max(spd1, spd2);

            spdLeft = obj.spdBrk(1:end-1);
            spdRight = obj.spdBrk(2:end);
            accDist = zeros(size(lo));

            for n = 1:numel(spdLeft)
                L = spdLeft(n);
                R = spdRight(n);
                loIn = lo >= L & lo < R;   % lo lies in [L, R)
                hiIn = hi >= L & hi < R;   % hi lies in [L, R)

                % Both speeds inside this segment: lo -> hi
                m = loIn & hiIn;
                accDist(m) = accDist(m) + obj.segDist(n, lo(m), hi(m));

                % Only lo inside: lo -> right node
                m = loIn & hi >= R;
                accDist(m) = accDist(m) + obj.segDist(n, lo(m), R);

                % Only hi inside: left node -> hi
                m = lo < L & hiIn;
                accDist(m) = accDist(m) + obj.segDist(n, L, hi(m));

                % Segment wholly inside [lo, hi]. Uses hi >= R (not hi > R)
                % so that hi landing exactly on R still counts this segment;
                % the "only hi inside" case above requires hi < R.
                m = lo < L & hi >= R;
                accDist(m) = accDist(m) + obj.accDistSeg(n);
            end

            accDist = abs(accDist);
            accDist(isnan(spd1) | isnan(spd2)) = NaN;
        end
    end

    methods (Access = private)
        function d = segDist(obj, n, v1, v2)
            % Distance from speed v1 to v2 along segment n (elementwise;
            % v1/v2 may be scalars or equally sized arrays).
            if obj.accVal(n) == obj.accVal(n+1)
                % Constant acceleration: sfun would be 0/0 since b == 0
                d = ( v2.^2 - v1.^2 ) ./ ( 2 * obj.c(n) );
            else
                d = obj.sfun( obj.b(n), obj.c(n), v1, v2 );
            end
        end
    end

end